# -*- coding:utf-8 -*- 
# Call module: call functions with ropchains 

from ropgenerator.IO import string_special, string_bold, banner, error, string_ropg
from ropgenerator.Constraints import Constraint, Assertion, BadBytes, RegsNotModified
from ropgenerator.Load import loadedBinary
from ropgenerator.exploit.Scanner import getFunctionAddress, getAllFunctions
from ropgenerator.semantic.ROPChains import ROPChain, validAddrStr
from ropgenerator.semantic.Engine import search, getBaseAssertion
from ropgenerator.Database import QueryType
from ropgenerator.CommonUtils import set_offset, reset_offset
from ropgenerator.exploit.HighLevelUtils import build_call
from ropgenerator.exploit.Utils import parseFunction, parse_bad_bytes, parse_keep_regs
import ropgenerator.Architecture as Arch
import sys

###################
#  CALL COMMAND   # 
###################

# Command-line flags understood by the 'call' command. Every option has a
# long form and a short form; call() accepts either one.

# Selects how the resulting ROPChain is printed
OPTION_OUTPUT = '--output-format'
OPTION_OUTPUT_SHORT = '-f'
# Options for output
OUTPUT_CONSOLE = 'console'
OUTPUT_PYTHON = 'python'
# NOTE(review): OUTPUT_RAW is defined here but call() only accepts
# OUTPUT_CONSOLE and OUTPUT_PYTHON as values for the output option
OUTPUT_RAW = 'raw'
OUTPUT = None # The one chosen (set by call(), read by print_chains())

# Bytes that must not appear in the payload
OPTION_BAD_BYTES = '--bad-bytes'
OPTION_BAD_BYTES_SHORT = '-b'

# Function to call, followed by its arguments
OPTION_CALL = "--call"
OPTION_CALL_SHORT = "-c"

# Registers that the ROPChain must not modify
OPTION_KEEP_REGS = '--keep-regs'
OPTION_KEEP_REGS_SHORT = '-k'

# List the available functions and exit
OPTION_LIST = "--list"
OPTION_LIST_SHORT = "-l"

# Show the help message and exit
OPTION_HELP = "--help"
OPTION_HELP_SHORT = "-h"

# Ask for the shortest matching ROPChain
OPTION_SHORTEST = '--shortest'
OPTION_SHORTEST_SHORT = '-s'

# Maximum length of the ROPChain, given in bytes
OPTION_LMAX = '--max-length'
OPTION_LMAX_SHORT = '-m'

# Offset to add to gadget addresses
OPTION_OFFSET = '--offset'
OPTION_OFFSET_SHORT = '-off'


# Help message of the 'call' command, built once at import time and printed
# by print_help(). It is assembled piece by piece: banner, usage line, one
# entry per option (short form, long form, description), the expected
# function format and an example. The tabs and newlines embedded in the
# literals lay the text out in columns on the console.
CMD_CALL_HELP =  banner([string_bold("'call' command"),\
                    string_special("(Call functions with ROPChains)")])
CMD_CALL_HELP += "\n\n\t"+string_bold("Usage:")+\
"\n\t\tcall [OPTIONS]"
CMD_CALL_HELP += "\n\n\t"+string_bold("Options")+":"
CMD_CALL_HELP += "\n\t\t"+string_special(OPTION_CALL_SHORT)+","+\
    string_special(OPTION_CALL)+" <function>\t Call a function"
CMD_CALL_HELP += "\n\n\t\t"+string_special(OPTION_BAD_BYTES_SHORT)+","+string_special(OPTION_BAD_BYTES)+" <bytes>\t Bad bytes for payload.\n\t\t\t\t\t Expected format is a list of bytes \n\t\t\t\t\t separated by comas (e.g '-b 0A,0B,2F')"
CMD_CALL_HELP += "\n\n\t\t"+string_special(OPTION_KEEP_REGS_SHORT)+","+string_special(OPTION_KEEP_REGS)+" <regs>\t Registers that shouldn't be modified.\n\t\t\t\t\t Expected format is a list of registers \n\t\t\t\t\t separated by comas (e.g '-k edi,eax')"
CMD_CALL_HELP += "\n\n\t\t"+string_special(OPTION_OFFSET_SHORT)+","+string_special(OPTION_OFFSET)+" <int>\t Offset to add to gadget addresses" 
CMD_CALL_HELP += "\n\n\t\t"+string_special(OPTION_LMAX_SHORT)+","+string_special(OPTION_LMAX)+" <int>\t Max length of the ROPChain in bytes"
CMD_CALL_HELP += "\n\n\t\t"+string_special(OPTION_SHORTEST_SHORT)+","+string_special(OPTION_SHORTEST)+"\t\t Find the shortest matching ROP-Chains"
CMD_CALL_HELP += "\n\n\t\t"+string_special(OPTION_LIST_SHORT)+","+\
    string_special(OPTION_LIST)+"\t\t List available functions"
# Only the 'console' and 'python' output formats are advertised here
CMD_CALL_HELP += "\n\n\t\t"+string_special(OPTION_OUTPUT_SHORT)+","+\
    string_special(OPTION_OUTPUT)+\
    " <fmt> Output format for ropchains.\n\t\t\t\t\t Expected format is one of the\n\t\t\t\t\t following: "+\
    string_special(OUTPUT_CONSOLE)+','+string_special(OUTPUT_PYTHON)
    
CMD_CALL_HELP += "\n\n\t\t"+string_special(OPTION_HELP_SHORT)+","+string_special(OPTION_HELP)+"\t\t Show this help"
CMD_CALL_HELP += "\n\n\t"+string_bold("Function format")+": "+\
    string_special("function")+"( "+string_special("arg1")+","+string_special(" ...")+\
    "," + string_special(" argN")+")"
CMD_CALL_HELP += "\n\n\t"+string_bold("Examples")+": "+\
    "\n\t\tcall strcpy(0x123456, 0x987654)"
    
def print_help():
    """
    Print the help message of the 'call' command (CMD_CALL_HELP)
    """
    help_message = CMD_CALL_HELP
    print(help_message)


def call(args):
    """
    Entry point of the 'call' command: parse the command arguments, then
    build and print a ROPChain that calls the requested function.

    args - list of str, the arguments given after the command name

    Nothing is returned. Results are printed with print_chains() and every
    problem (bad option, missing value, failure reported by build_call) is
    reported with error(). With no arguments, or with the help option, the
    help message is printed. With the list option, the available functions
    are printed and nothing else is done.

    The module-level OUTPUT variable is updated with the selected output
    format. The gadget offset installed with set_offset() is always
    restored with reset_offset() before returning, even if building the
    chain raises an exception (the exception is then propagated unchanged).
    """
    global OUTPUT
    # Parsing arguments
    if( not args):
        print_help()
        return
    OUTPUT = OUTPUT_CONSOLE
    funcName = None
    funcArgs = None
    i = 0
    seenOutput = False
    seenFunction = False
    seenBadBytes = False
    seenKeepRegs = False
    seenShortest = False
    seenLmax = False
    seenOffset = False
    clmax = None
    offset = 0

    constraint = Constraint()
    assertion = getBaseAssertion()
    while i < len(args):
        if( args[i] in [OPTION_LIST, OPTION_LIST_SHORT]):
            func_list = getAllFunctions()
            print_functions(func_list)
            return
        if( args[i] in [OPTION_BAD_BYTES, OPTION_BAD_BYTES_SHORT]):
            if( seenBadBytes ):
                error("Error. '" + args[i] + "' option should be used only once")
                return
            if( i+1 >= len(args)):
                error("Error. Missing bad bytes after option '"+args[i]+"'")
                return
            seenBadBytes = True
            (success, res) = parse_bad_bytes(args[i+1])
            if( not success ):
                error(res)
                return
            i = i+2
            constraint = constraint.add(BadBytes(res))
        elif( args[i] in [OPTION_KEEP_REGS, OPTION_KEEP_REGS_SHORT]):
            if( seenKeepRegs ):
                error("Error. '" + args[i] + "' option should be used only once")
                return
            if( i+1 >= len(args)):
                error("Error. Missing register after option '"+args[i]+"'")
                return
            seenKeepRegs = True
            (success, res) = parse_keep_regs(args[i+1])
            if( not success ):
                error(res)
                return
            i = i+2
            constraint = constraint.add(RegsNotModified(res))
        elif( args[i] in [OPTION_CALL, OPTION_CALL_SHORT] ):
            if( not loadedBinary() ):
                error("Error. You should load a binary before building ROPChains")
                return
            elif( seenFunction ):
                error("Option '{}' should be used only once".format(args[i]))
                return
            # The function expression may have been split on spaces by the
            # caller: glue together everything up to the next option.
            # startswith() is used so that an empty argument can not raise
            # an IndexError.
            userInput = ''
            i += 1
            while( i < len(args) and not args[i].startswith("-")):
                userInput += args[i]
                i += 1
            (funcName, funcArgs ) = parseFunction(userInput)
            if( not funcName):
                return
            seenFunction = True
        elif( args[i] in [OPTION_OUTPUT, OPTION_OUTPUT_SHORT]):
            if( seenOutput ):
                error("Option '{}' should be used only once".format(args[i]))
                return
            if( i+1 >= len(args)):
                error("Error. Missing output format after option '"+args[i]+"'")
                return
            if( args[i+1] in [OUTPUT_CONSOLE, OUTPUT_PYTHON]):
                OUTPUT = args[i+1]
                seenOutput = True
                i += 2
            else:
                error("Error. Unknown output format: {}".format(args[i+1]))
                return
        elif( args[i] in [OPTION_SHORTEST, OPTION_SHORTEST_SHORT]):
            if( seenShortest ):
                error("Option '{}' should be used only once".format(args[i]))
                return
            seenShortest = True
            i += 1
        elif( args[i] in [OPTION_LMAX, OPTION_LMAX_SHORT] ):
            if( seenLmax ):
                error("Option '{}' should be used only once".format(args[i]))
                return
            if( i+1 >= len(args)):
                error("Error. Missing length after option '"+args[i]+"'")
                return
            try:
                clmax = int(args[i+1])
                if( clmax < Arch.octets() ):
                    raise ValueError("length smaller than one ropchain element")
                # Convert number of bytes into number of ropchain elements.
                # Floor division keeps the result an integer whatever the
                # division semantics of the interpreter.
                clmax //= Arch.octets()
            except Exception:
                error("Error. '" + args[i+1] +"' bytes is not valid")
                return
            i += 2
            seenLmax = True
        elif( args[i] in [OPTION_OFFSET, OPTION_OFFSET_SHORT] ):
            if( seenOffset ):
                error("Option '{}' should be used only once".format(args[i]))
                return
            if( i+1 >= len(args)):
                error("Error. Missing offset after option '"+args[i]+"'")
                return
            # The offset is read as decimal first, then as hexadecimal.
            # Negative values are rejected in both bases.
            parsedOffset = None
            for base in (10, 16):
                try:
                    candidate = int(args[i+1], base)
                except ValueError:
                    continue
                if( candidate >= 0 ):
                    parsedOffset = candidate
                    break
            if( parsedOffset is None ):
                error("Error. '" + args[i+1] +"' is not a valid offset")
                return
            offset = parsedOffset
            i += 2
            seenOffset = True
        elif( args[i] in [OPTION_HELP, OPTION_HELP_SHORT]):
            print_help()
            return
        else:
            error("Error. Unknown option '{}'".format(args[i]))
            return

    if( not funcName ):
        error("Missing function to call")
        return
    # Set offset
    if( not set_offset(offset) ):
        error("Error. Your offset is too big :'(")
        return
    try:
        res = build_call(funcName, funcArgs, constraint, assertion, clmax=clmax, optimizeLen=seenShortest)
        if( isinstance(res, str) ):
            # build_call reports a failure by returning a message
            error(res)
        else:
            print_chains([res], "Built matching ROPChain", constraint.getBadBytes())
    finally:
        # Reset offset, whether the chain was built or an exception is
        # propagating to the caller
        reset_offset()
            



##########################
# Pretty print functions #
##########################

def print_chains(chainList, msg, badBytes=None):
    """
    Print a list of ROPChains in the output format stored in OUTPUT.

    chainList - list of chains; each one must provide strConsole() and
                strPython(), called with (Arch.currentArch.bits, badBytes)
    msg       - str, title printed in bold before the chains
    badBytes  - bad bytes forwarded to the chain formatting methods
                (defaults to an empty list)

    If chainList is empty, a "no matching ROPChain" message is printed
    instead. If OUTPUT is neither OUTPUT_CONSOLE nor OUTPUT_PYTHON, only
    the title is printed.
    """
    # A fresh list per call: a shared mutable default could be altered by
    # the formatting methods it is handed to
    if( badBytes is None ):
        badBytes = []
    if( not chainList ):
        print(string_bold("\n\tNo matching ROPChain found"))
        return

    print(string_bold('\n\t'+msg))
    if( OUTPUT == OUTPUT_CONSOLE ):
        render = lambda chain: chain.strConsole(Arch.currentArch.bits, badBytes)
    elif( OUTPUT == OUTPUT_PYTHON ):
        render = lambda chain: chain.strPython(Arch.currentArch.bits, badBytes)
    else:
        # Unknown output format: nothing more to print
        return

    sep = "------------------"
    for (index, chain) in enumerate(chainList):
        # Every chain but the first one is preceded by a separator line
        prefix = '' if index == 0 else '\t'+sep
        print(prefix + "\n" + render(chain))

def print_functions(func_list):
    """
    Print a two-column table (name, address) of the given functions,
    sorted by function name.

    func_list - list of pairs (str, int) = (funcName, funcAddress)
    """
    space = 28  # Width of the 'Function' column
    print(banner([string_bold("Available functions")]))
    print("\tFunction" + " "*(space-8) + "Address")
    print("\t------------------------------------")
    for (funcName, funcAddr) in sorted(func_list, key=lambda x:x[0] ):
        padding = space - len(funcName)
        if( padding < 1 ):
            # The name fills or overflows its column: keep a small gap so
            # that the address is never glued to the name
            padding = 2
        print("\t"+string_special(funcName)+ " "*padding + hex(funcAddr))
    print("")
    
